{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_uuid": "267d5504009eca2b809a2691131366baebe98436"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from glob import glob\n",
    "import datetime\n",
    "import random\n",
    "from PIL import Image\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_uuid": "fc376df46433261bfeb643a95793718a9d969ed1"
   },
   "outputs": [],
   "source": [
    "def generator(z, output_channel_dim, training):\n",
    "    with tf.variable_scope(\"generator\", reuse= not training):\n",
    "        \n",
    "        # 8x8x1024\n",
    "        fully_connected = tf.layers.dense(z, 8*8*1024)\n",
    "        fully_connected = tf.reshape(fully_connected, (-1, 8, 8, 1024))\n",
    "        fully_connected = tf.nn.leaky_relu(fully_connected)\n",
    "\n",
    "        # 8x8x1024 -> 16x16x512\n",
    "        trans_conv1 = tf.layers.conv2d_transpose(inputs=fully_connected,\n",
    "                                                 filters=512,\n",
    "                                                 kernel_size=[5,5],\n",
    "                                                 strides=[2,2],\n",
    "                                                 padding=\"SAME\",\n",
    "                                                 kernel_initializer=tf.truncated_normal_initializer(stddev=WEIGHT_INIT_STDDEV),\n",
    "                                                 name=\"trans_conv1\")\n",
    "        batch_trans_conv1 = tf.layers.batch_normalization(inputs = trans_conv1,\n",
    "                                                          training=training,\n",
    "                                                          epsilon=EPSILON,\n",
    "                                                          name=\"batch_trans_conv1\")\n",
    "        trans_conv1_out = tf.nn.leaky_relu(batch_trans_conv1,\n",
    "                                           name=\"trans_conv1_out\")\n",
    "        \n",
    "        # 16x16x512 -> 32x32x256\n",
    "        trans_conv2 = tf.layers.conv2d_transpose(inputs=trans_conv1_out,\n",
    "                                                 filters=256,\n",
    "                                                 kernel_size=[5,5],\n",
    "                                                 strides=[2,2],\n",
    "                                                 padding=\"SAME\",\n",
    "                                                 kernel_initializer=tf.truncated_normal_initializer(stddev=WEIGHT_INIT_STDDEV),\n",
    "                                                 name=\"trans_conv2\")\n",
    "        batch_trans_conv2 = tf.layers.batch_normalization(inputs = trans_conv2,\n",
    "                                                          training=training,\n",
    "                                                          epsilon=EPSILON,\n",
    "                                                          name=\"batch_trans_conv2\")\n",
    "        trans_conv2_out = tf.nn.leaky_relu(batch_trans_conv2,\n",
    "                                           name=\"trans_conv2_out\")\n",
    "        \n",
    "        # 32x32x256 -> 64x64x128\n",
    "        trans_conv3 = tf.layers.conv2d_transpose(inputs=trans_conv2_out,\n",
    "                                                 filters=128,\n",
    "                                                 kernel_size=[5,5],\n",
    "                                                 strides=[2,2],\n",
    "                                                 padding=\"SAME\",\n",
    "                                                 kernel_initializer=tf.truncated_normal_initializer(stddev=WEIGHT_INIT_STDDEV),\n",
    "                                                 name=\"trans_conv3\")\n",
    "        batch_trans_conv3 = tf.layers.batch_normalization(inputs = trans_conv3,\n",
    "                                                          training=training,\n",
    "                                                          epsilon=EPSILON,\n",
    "                                                          name=\"batch_trans_conv3\")\n",
    "        trans_conv3_out = tf.nn.leaky_relu(batch_trans_conv3,\n",
    "                                           name=\"trans_conv3_out\")\n",
    "        \n",
    "        # 64x64x128 -> 128x128x64\n",
    "        trans_conv4 = tf.layers.conv2d_transpose(inputs=trans_conv3_out,\n",
    "                                                 filters=64,\n",
    "                                                 kernel_size=[5,5],\n",
    "                                                 strides=[2,2],\n",
    "                                                 padding=\"SAME\",\n",
    "                                                 kernel_initializer=tf.truncated_normal_initializer(stddev=WEIGHT_INIT_STDDEV),\n",
    "                                                 name=\"trans_conv4\")\n",
    "        batch_trans_conv4 = tf.layers.batch_normalization(inputs = trans_conv4,\n",
    "                                                          training=training,\n",
    "                                                          epsilon=EPSILON,\n",
    "                                                          name=\"batch_trans_conv4\")\n",
    "        trans_conv4_out = tf.nn.leaky_relu(batch_trans_conv4,\n",
    "                                           name=\"trans_conv4_out\")\n",
    "        \n",
    "        # 128x128x64 -> 128x128x3\n",
    "        logits = tf.layers.conv2d_transpose(inputs=trans_conv4_out,\n",
    "                                            filters=3,\n",
    "                                            kernel_size=[5,5],\n",
    "                                            strides=[1,1],\n",
    "                                            padding=\"SAME\",\n",
    "                                            kernel_initializer=tf.truncated_normal_initializer(stddev=WEIGHT_INIT_STDDEV),\n",
    "                                            name=\"logits\")\n",
    "        out = tf.tanh(logits, name=\"out\")\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_uuid": "ba53a4bb09dbcd57d3e3392f74ccd054ecf23ecb"
   },
   "outputs": [],
   "source": [
    "def discriminator(x, reuse):\n",
    "    with tf.variable_scope(\"discriminator\", reuse=reuse): \n",
    "        \n",
    "        # 128*128*3 -> 64x64x64 \n",
    "        conv1 = tf.layers.conv2d(inputs=x,\n",
    "                                 filters=64,\n",
    "                                 kernel_size=[5,5],\n",
    "                                 strides=[2,2],\n",
    "                                 padding=\"SAME\",\n",
    "                                 kernel_initializer=tf.truncated_normal_initializer(stddev=WEIGHT_INIT_STDDEV),\n",
    "                                 name='conv1')\n",
    "        batch_norm1 = tf.layers.batch_normalization(conv1,\n",
    "                                                    training=True,\n",
    "                                                    epsilon=EPSILON,\n",
    "                                                    name='batch_norm1')\n",
    "        conv1_out = tf.nn.leaky_relu(batch_norm1,\n",
    "                                     name=\"conv1_out\")\n",
    "        \n",
    "        # 64x64x64-> 32x32x128 \n",
    "        conv2 = tf.layers.conv2d(inputs=conv1_out,\n",
    "                                 filters=128,\n",
    "                                 kernel_size=[5, 5],\n",
    "                                 strides=[2, 2],\n",
    "                                 padding=\"SAME\",\n",
    "                                 kernel_initializer=tf.truncated_normal_initializer(stddev=WEIGHT_INIT_STDDEV),\n",
    "                                 name='conv2')\n",
    "        batch_norm2 = tf.layers.batch_normalization(conv2,\n",
    "                                                    training=True,\n",
    "                                                    epsilon=EPSILON,\n",
    "                                                    name='batch_norm2')\n",
    "        conv2_out = tf.nn.leaky_relu(batch_norm2,\n",
    "                                     name=\"conv2_out\")\n",
    "        \n",
    "        # 32x32x128 -> 16x16x256  \n",
    "        conv3 = tf.layers.conv2d(inputs=conv2_out,\n",
    "                                 filters=256,\n",
    "                                 kernel_size=[5, 5],\n",
    "                                 strides=[2, 2],\n",
    "                                 padding=\"SAME\",\n",
    "                                 kernel_initializer=tf.truncated_normal_initializer(stddev=WEIGHT_INIT_STDDEV),\n",
    "                                 name='conv3')\n",
    "        batch_norm3 = tf.layers.batch_normalization(conv3,\n",
    "                                                    training=True,\n",
    "                                                    epsilon=EPSILON,\n",
    "                                                    name='batch_norm3')\n",
    "        conv3_out = tf.nn.leaky_relu(batch_norm3,\n",
    "                                     name=\"conv3_out\")\n",
    "        \n",
    "        # 16x16x256 -> 16x16x512\n",
    "        conv4 = tf.layers.conv2d(inputs=conv3_out,\n",
    "                                 filters=512,\n",
    "                                 kernel_size=[5, 5],\n",
    "                                 strides=[1, 1],\n",
    "                                 padding=\"SAME\",\n",
    "                                 kernel_initializer=tf.truncated_normal_initializer(stddev=WEIGHT_INIT_STDDEV),\n",
    "                                 name='conv4')\n",
    "        batch_norm4 = tf.layers.batch_normalization(conv4,\n",
    "                                                    training=True,\n",
    "                                                    epsilon=EPSILON,\n",
    "                                                    name='batch_norm4')\n",
    "        conv4_out = tf.nn.leaky_relu(batch_norm4,\n",
    "                                     name=\"conv4_out\")\n",
    "        \n",
    "        # 16x16x512 -> 8x8x1024\n",
    "        conv5 = tf.layers.conv2d(inputs=conv4_out,\n",
    "                                filters=1024,\n",
    "                                kernel_size=[5, 5],\n",
    "                                strides=[2, 2],\n",
    "                                padding=\"SAME\",\n",
    "                                kernel_initializer=tf.truncated_normal_initializer(stddev=WEIGHT_INIT_STDDEV),\n",
    "                                name='conv5')\n",
    "        batch_norm5 = tf.layers.batch_normalization(conv5,\n",
    "                                                    training=True,\n",
    "                                                    epsilon=EPSILON,\n",
    "                                                    name='batch_norm5')\n",
    "        conv5_out = tf.nn.leaky_relu(batch_norm5,\n",
    "                                     name=\"conv5_out\")\n",
    "\n",
    "        flatten = tf.reshape(conv5_out, (-1, 8*8*1024))\n",
    "        logits = tf.layers.dense(inputs=flatten,\n",
    "                                 units=1,\n",
    "                                 activation=None)\n",
    "        out = tf.sigmoid(logits)\n",
    "        return out, logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_uuid": "8ef5bbb8e4d157577f1b15600aa64cb40289a754"
   },
   "outputs": [],
   "source": [
    "def model_loss(input_real, input_z, output_channel_dim):\n",
    "    g_model = generator(input_z, output_channel_dim, True)\n",
    "\n",
    "    noisy_input_real = input_real + tf.random_normal(shape=tf.shape(input_real),\n",
    "                                                     mean=0.0,\n",
    "                                                     stddev=random.uniform(0.0, 0.1),\n",
    "                                                     dtype=tf.float32)\n",
    "    \n",
    "    d_model_real, d_logits_real = discriminator(noisy_input_real, reuse=False)\n",
    "    d_model_fake, d_logits_fake = discriminator(g_model, reuse=True)\n",
    "    \n",
    "    d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_real,\n",
    "                                                                         labels=tf.ones_like(d_model_real)*random.uniform(0.9, 1.0)))\n",
    "    d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,\n",
    "                                                                         labels=tf.zeros_like(d_model_fake)))\n",
    "    d_loss = tf.reduce_mean(0.5 * (d_loss_real + d_loss_fake))\n",
    "    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=d_logits_fake,\n",
    "                                                                    labels=tf.ones_like(d_model_fake)))\n",
    "    return d_loss, g_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_uuid": "01820b197de1b6d0ca39592043308557d941937a"
   },
   "outputs": [],
   "source": [
    "def model_optimizers(d_loss, g_loss):\n",
    "    t_vars = tf.trainable_variables()\n",
    "    g_vars = [var for var in t_vars if var.name.startswith(\"generator\")]\n",
    "    d_vars = [var for var in t_vars if var.name.startswith(\"discriminator\")]\n",
    "    \n",
    "    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)\n",
    "    gen_updates = [op for op in update_ops if op.name.startswith('generator')]\n",
    "    \n",
    "    with tf.control_dependencies(gen_updates):\n",
    "        d_train_opt = tf.train.AdamOptimizer(learning_rate=LR_D, beta1=BETA1).minimize(d_loss, var_list=d_vars)\n",
    "        g_train_opt = tf.train.AdamOptimizer(learning_rate=LR_G, beta1=BETA1).minimize(g_loss, var_list=g_vars)  \n",
    "    return d_train_opt, g_train_opt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_inputs(real_dim, z_dim):\n",
    "    inputs_real = tf.placeholder(tf.float32, (None, *real_dim), name='inputs_real')\n",
    "    inputs_z = tf.placeholder(tf.float32, (None, z_dim), name=\"input_z\")\n",
    "    learning_rate_G = tf.placeholder(tf.float32, name=\"lr_g\")\n",
    "    learning_rate_D = tf.placeholder(tf.float32, name=\"lr_d\")\n",
    "    return inputs_real, inputs_z, learning_rate_G, learning_rate_D"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_uuid": "3a7808bacf25966ab5a8ca06e2e1075ac8eeb662"
   },
   "outputs": [],
   "source": [
    "def show_samples(sample_images, name, epoch):\n",
    "    figure, axes = plt.subplots(1, len(sample_images), figsize = (IMAGE_SIZE, IMAGE_SIZE))\n",
    "    for index, axis in enumerate(axes):\n",
    "        axis.axis('off')\n",
    "        image_array = sample_images[index]\n",
    "        axis.imshow(image_array)\n",
    "        image = Image.fromarray(image_array)\n",
    "        image.save(name+\"_\"+str(epoch)+\"_\"+str(index)+\".png\") \n",
    "    plt.savefig(name+\"_\"+str(epoch)+\".png\", bbox_inches='tight', pad_inches=0)\n",
    "    plt.show()\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(sess, input_z, out_channel_dim, epoch):\n",
    "    example_z = np.random.uniform(-1, 1, size=[SAMPLES_TO_SHOW, input_z.get_shape().as_list()[-1]])\n",
    "    samples = sess.run(generator(input_z, out_channel_dim, False), feed_dict={input_z: example_z})\n",
    "    sample_images = [((sample + 1.0) * 127.5).astype(np.uint8) for sample in samples]\n",
    "    show_samples(sample_images, OUTPUT_DIR + \"samples\", epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_uuid": "0ed8f1c378f936ac81ae89b35d5b6914cd6efbbb"
   },
   "outputs": [],
   "source": [
    "def summarize_epoch(epoch, duration, sess, d_losses, g_losses, input_z, data_shape):\n",
    "    minibatch_size = int(data_shape[0]//BATCH_SIZE)\n",
    "    print(\"Epoch {}/{}\".format(epoch, EPOCHS),\n",
    "          \"\\nDuration: {:.5f}\".format(duration),\n",
    "          \"\\nD Loss: {:.5f}\".format(np.mean(d_losses[-minibatch_size:])),\n",
    "          \"\\nG Loss: {:.5f}\".format(np.mean(g_losses[-minibatch_size:])))\n",
    "    fig, ax = plt.subplots()\n",
    "    plt.plot(d_losses, label='Discriminator', alpha=0.6)\n",
    "    plt.plot(g_losses, label='Generator', alpha=0.6)\n",
    "    plt.title(\"Losses\")\n",
    "    plt.legend()\n",
    "    plt.savefig(OUTPUT_DIR + \"losses_\" + str(epoch) + \".png\")\n",
    "    plt.show()\n",
    "    plt.close()\n",
    "    test(sess, input_z, data_shape[3], epoch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_batches(data):\n",
    "    batches = []\n",
    "    for i in range(int(data.shape[0]//BATCH_SIZE)):\n",
    "        batch = data[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]\n",
    "        augmented_images = []\n",
    "        for img in batch:\n",
    "            image = Image.fromarray(img)\n",
    "            if random.choice([True, False]):\n",
    "                image = image.transpose(Image.FLIP_LEFT_RIGHT)\n",
    "            augmented_images.append(np.asarray(image))\n",
    "        batch = np.asarray(augmented_images)\n",
    "        normalized_batch = (batch / 127.5) - 1.0\n",
    "        batches.append(normalized_batch)\n",
    "    return batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(get_batches, data_shape, checkpoint_to_load=None):\n",
    "    input_images, input_z, lr_G, lr_D = model_inputs(data_shape[1:], NOISE_SIZE)\n",
    "    d_loss, g_loss = model_loss(input_images, input_z, data_shape[3])\n",
    "    d_opt, g_opt = model_optimizers(d_loss, g_loss)\n",
    "    \n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        epoch = 0\n",
    "        iteration = 0\n",
    "        d_losses = []\n",
    "        g_losses = []\n",
    "        \n",
    "        for epoch in range(EPOCHS):        \n",
    "            epoch += 1\n",
    "            start_time = time.time()\n",
    "\n",
    "            for batch_images in get_batches:\n",
    "                iteration += 1\n",
    "                batch_z = np.random.uniform(-1, 1, size=(BATCH_SIZE, NOISE_SIZE))\n",
    "                _ = sess.run(d_opt, feed_dict={input_images: batch_images, input_z: batch_z, lr_D: LR_D})\n",
    "                _ = sess.run(g_opt, feed_dict={input_images: batch_images, input_z: batch_z, lr_G: LR_G})\n",
    "                d_losses.append(d_loss.eval({input_z: batch_z, input_images: batch_images}))\n",
    "                g_losses.append(g_loss.eval({input_z: batch_z}))\n",
    "\n",
    "            summarize_epoch(epoch, time.time()-start_time, sess, d_losses, g_losses, input_z, data_shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_uuid": "d857345ed6c3ba7e4f614fa90b5ca42adb9917cf"
   },
   "outputs": [],
   "source": [
    "# Paths\n",
    "INPUT_DATA_DIR = \"\" # Path to the folder with input images. For more info check simspons_dataset.txt\n",
    "OUTPUT_DIR = './{date:%Y-%m-%d_%H:%M:%S}/'.format(date=datetime.datetime.now())\n",
    "if not os.path.exists(OUTPUT_DIR):\n",
    "    os.makedirs(OUTPUT_DIR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "IMAGE_SIZE = 128\n",
    "NOISE_SIZE = 100\n",
    "LR_D = 0.00004\n",
    "LR_G = 0.0004\n",
    "BATCH_SIZE = 64\n",
    "EPOCHS = 300\n",
    "BETA1 = 0.5\n",
    "WEIGHT_INIT_STDDEV = 0.02\n",
    "EPSILON = 0.00005\n",
    "SAMPLES_TO_SHOW = 5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_uuid": "1ee6349afccb4d2f29f36d6605dd2f156350821a",
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Training\n",
    "input_images = np.asarray([np.asarray(Image.open(file).resize((IMAGE_SIZE, IMAGE_SIZE))) for file in glob(INPUT_DATA_DIR + '*')])\n",
    "print (\"Input: \" + str(input_images.shape))\n",
    "\n",
    "np.random.shuffle(input_images)\n",
    "\n",
    "sample_images = random.sample(list(input_images), SAMPLES_TO_SHOW)\n",
    "show_samples(sample_images, OUTPUT_DIR + \"inputs\", 0)\n",
    "\n",
    "with tf.Graph().as_default():\n",
    "    train(get_batches(input_images), input_images.shape)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
